#pragma once

#include "RTreeNodeLeaf.hpp"

namespace zzz{
//Randomized Tree
//To classify T
template <typename T,int RTREE_NODE_SPLIT=2>
class RTree
{
public:
  RTree(typename RTreeFunction<T>::RTreeFuncGenerator funcgen):ntag_(0),funcgen_(funcgen)
  {
    head_=new RTreeLeaf<T,RTREE_NODE_SPLIT>(0,0,funcgen_);
  }
  ~RTree(void)
  {
    delete head_;
  }

  //train and classify

  void GrowTrain(const T *v, const int tag){head_->GrowTrain(v,tag);}
  void Grow()
  {
    head_->UpdateP();
    RTreeNode<T,RTREE_NODE_SPLIT> *newnode=((RTreeLeaf<T,RTREE_NODE_SPLIT>*)head_)->Grow();
    delete head_;
    head_=newnode;
  }
  void ClearData(const int ntag){ntag_=ntag;head_->ClearData(ntag);}

  void Train(const T *v, const int tag){head_->Train(v,tag);}
  const double *Classify(const T *v) const {return head_->Classify(v);}
  void UpdateP(){head_->UpdateP();}

  //debug
  void Print()
  {
    zout<<"|-\n";
    head_->Print("| ");
  }
//private:
  RTreeNodeBase<T> *head_;
  int ntag_;
  int ntrained_;
  int height_;
  typename RTreeFunction<T>::RTreeFuncGenerator funcgen_;
};

//Randomized Multiple Tree structure
template <typename T,int RTREE_NODE_SPLIT=2>
class RTrees
{
public:
  RTrees(typename RTreeFunction<T>::RTreeFuncGenerator funcgen):ntag_(0),ntree_(0),trees_(NULL),funcgen_(funcgen)
  {
  }

  ~RTrees()
  {
    for (int i=0; i<ntree_; i++)
      delete trees_[i];
    delete trees_;
  }

  void BuildTrees(const int ntree,vector<pair<const T*,int> > &traindata,const int trainntag)
  {
    ntag_=trainntag;

    for (int i=0; i<ntree_; i++)
      delete trees_[i];
    delete trees_;
    ntree_=ntree;

    trees_=new RTree<T,RTREE_NODE_SPLIT>*[ntree_];
    ntree_=ntree;
    for (int i=0; i<ntree; i++)
    {
      zout<<"BuildTree: "<<i<<endl;
      trees_[i]=new RTree<T,RTREE_NODE_SPLIT>(funcgen_);
      trees_[i]->ClearData(trainntag);
      for (size_t j=0; j<traindata.size(); j++)
        trees_[i]->GrowTrain(traindata[j].first,traindata[j].second);
      trees_[i]->Grow();
    }
  }

  void ClearData(const int ntag)
  {
    ntag_=ntag;
    for (int i=0; i<ntree_; i++)
      trees_[i]->ClearData(ntag);
  }

  int Classify(const T *v, double *possibility) const
  {
    double *p=new double[ntag_];
    memset(p,0,sizeof(double)*ntag_);
    for (int i=0; i<ntree_; i++)
    {
      const double *thisp=trees_[i]->Classify(v);
      for (int j=0; j<ntag_; j++) p[j]+=thisp[j];
    }
    double maxp=0;
    int maxi=0;
    for (int i=0; i<ntag_; i++)
    {
      if (p[i]>maxp)
      {
        maxp=p[i];
        maxi=i;
      }
    }
    *possibility=maxp;
    delete[] p;
    return maxi;
  }

  void Train(const T *v, const int tag)
  {
    for (int i=0; i<ntree_; i++)
      trees_[i]->Train(v,tag);
  }

  void Train( vector<pair<const T*,int> > &traindata)
  {
    for (size_t i=0; i<traindata.size(); i++)
      Train(traindata[i].first,traindata[i].second);
    UpdateP();
  }

  void UpdateP()
  {
    for (int i=0; i<ntree_; i++) 
      trees_[i]->UpdateP();
  }

  void Print()
  {
    for (int i=0; i<ntree_; i++)
    {
      zout<<"Tree: "<<i<<endl;
      trees_[i]->Print();
    }
  }
  RTree<T,RTREE_NODE_SPLIT> **trees_;
  int ntree_;
  int ntag_;
  typename RTreeFunction<T>::RTreeFuncGenerator funcgen_;
};
}